Switch gradient checkpointing default to use_reentrant=False (PyTorch recommended) #4811
Conversation
Thanks for the fix! I have a couple of questions below.
Additionally, there are 2 new CI failures; we should check whether they are caused by this PR and, if so, whether they should be fixed here as well:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
FAILED tests/experimental/test_nash_md_trainer.py::TestNashMDTrainer::test_training_pre_pefted_model_implicit_ref_with_reward_model - RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
FAILED tests/experimental/test_xpo_trainer.py::TestXPOTrainer::test_training_pre_pefted_model_implicit_ref - RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
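For reference, this error is typically seen when non-reentrant checkpointing runs on a model whose trainable path starts from frozen inputs (for example a pre-PEFTed backbone). A minimal sketch of the usual workaround, assuming a plain Transformers causal LM (the model name is only an example, not what the failing tests use):

```python
from transformers import AutoModelForCausalLM

# Example model only; the failing tests use their own tiny test models.
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

# With a pre-PEFTed (largely frozen) backbone and use_reentrant=False, a
# checkpointed segment may receive no input that requires grad, and backward
# then fails with "element 0 of tensors does not require grad and does not
# have a grad_fn". The usual workaround is to force the embedding outputs to
# require grad before enabling checkpointing:
model.enable_input_require_grads()
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
```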
Fixed.
Summary

Set use_reentrant=False by default

This PR defaults gradient checkpointing to use_reentrant=False in TRL when no explicit value is provided. PyTorch now recommends the non-reentrant checkpointing variant, see https://docs.pytorch.org/docs/stable/checkpoint.html. However, Transformers still defaults to use_reentrant=True because it was explicitly pinned in the past to silence a PyTorch warning during a transition period, and the default was never updated afterward. Until this is fixed upstream and released (see huggingface/transformers#43203), TRL aligns with the current PyTorch recommendation by setting use_reentrant=False by default, while fully preserving any user-provided value.
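As a rough illustration of the behavior described above (not the exact TRL code), the default is only filled in when the user has not set a value themselves:

```python
from transformers import AutoModelForCausalLM

# Example model only.
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

# Suppose this comes from the training config; None means "not provided by the user".
user_kwargs = None

# Fill in use_reentrant=False only when the user left it unset, so an explicit
# user choice (True or False) is always preserved.
gradient_checkpointing_kwargs = dict(user_kwargs or {})
gradient_checkpointing_kwargs.setdefault("use_reentrant", False)

model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
)
```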
Fix "Expected to mark a variable ready only once"

This PR also fixes an issue that seems unrelated but is in fact related: #4782
Remove ScriptArguments.gradient_checkpointing_use_reentrant

ScriptArguments.gradient_checkpointing_use_reentrant exists but is never used. This is misleading, so this PR removes the argument.
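For users, overriding the new default would look roughly like this; SFTConfig is used here only as an example of a TRL config that inherits gradient_checkpointing_kwargs from transformers.TrainingArguments:

```python
from trl import SFTConfig

# output_dir and the explicit kwargs are illustrative; the point is that a
# user-provided use_reentrant value is kept as-is, and only a missing value
# falls back to the new default of use_reentrant=False.
config = SFTConfig(
    output_dir="tmp-sft",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": True},
)
```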